{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo 1~3\n",
    "You can find a simple version in `elegantrl/tutorial/run.py`\n",
    "You can also find demo 1~3 in `elegantrl/run.py` (formal version)\n",
    "\n",
    "elegantrl/tutorial <1000 lines \n",
    "```\n",
    "net.py   # 160 lines\n",
    "agent.py # 530 lines\n",
    "run.py   # 320 lines\n",
    "env.py   # 160 lines (not necessary)\n",
    "```\n",
    "The structtion of formal version is similar to tutorial version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from elegantrl.tutorial.run import Arguments, train_and_evaluate\n",
    "from elegantrl.tutorial.env import PreprocessEnv\n",
    "import gym\n",
    "gym.logger.set_level(40)  # Block warning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo 1: Discrete action space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''choose an DRL algorithm'''\n",
    "from elegantrl.tutorial.agent import AgentDoubleDQN  # AgentDQN\n",
    "\n",
    "args = Arguments(agent=None, env=None, gpu_id=None)\n",
    "args.agent = AgentDoubleDQN()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "| env_name: CartPole-v0, action space if_discrete: True\n",
      "| state_dim: 4, action_dim: 2, action_max: 1\n",
      "| max_step: 200 target_reward: 195.0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'TotalStep: 2e3, TargetReward: , UsedTime: 10s'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''choose environment'''\n",
    "args.env = PreprocessEnv(env=gym.make('CartPole-v0'))\n",
    "args.net_dim = 2 ** 7  # change a default hyper-parameters\n",
    "args.batch_size = 2 ** 7\n",
    "\"TotalStep: 2e3, TargetReward: , UsedTime: 10s\"\n",
    "\n",
    "# args.env = PreprocessEnv(env=gym.make('LunarLander-v2'))\n",
    "# args.net_dim = 2 ** 8\n",
    "# args.batch_size = 2 ** 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| GPU id: 0, cwd: ./CartPole-v0_0\n",
      "| Remove history\n",
      "ID      Step      MaxR |    avgR      stdR       objA      objC\n",
      "0   0.00e+00    200.00 |\n",
      "ID      Step   TargetR |    avgR      stdR   UsedTime  ########\n",
      "0   1.02e+03    195.00 |  200.00      0.00         12  ########\n"
     ]
    }
   ],
   "source": [
    "'''train and evaluate'''\n",
    "train_and_evaluate(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo 2: Continuous action space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''DEMO 2.1: choose an off-policy DRL algorithm'''\n",
    "from elegantrl.agent import AgentSAC  # AgentTD3, AgentDDPG\n",
    "args = Arguments(if_on_policy=False)\n",
    "args.agent = AgentSAC()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''DEMO 2.2: choose an on-policy DRL algorithm'''\n",
    "from elegantrl.tutorial.agent import AgentPPO \n",
    "args = Arguments(if_on_policy=True)  # hyper-parameters of on-policy is different from off-policy\n",
    "args.agent = AgentPPO()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "| env_name: Pendulum-v0, action space if_discrete: False\n",
      "| state_dim: 3, action_dim: 1, action_max: 2.0\n",
      "| max_step: 200 target_reward: -200\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'TotalStep: 2e5, TargetReward: 300, UsedTime: 5000s'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''choose environment'''\n",
    "env = gym.make('Pendulum-v0')\n",
    "env.target_reward = -200  # set target_reward manually for env 'Pendulum-v0'\n",
    "args.env = PreprocessEnv(env=env)\n",
    "args.reward_scale = 2 ** -3  # RewardRange: -1800 < -200 < -50 < 0\n",
    "args.net_dim = 2 ** 7\n",
    "args.batch_size = 2 ** 7\n",
    "\"TotalStep: 3e5, TargetReward: -200, UsedTime: 300s\"\n",
    "# args.env = PreprocessEnv(env=gym.make('LunarLanderContinuous-v2'))\n",
    "# args.reward_scale = 2 ** 0  # RewardRange: -800 < -200 < 200 < 302\n",
    "# \"TotalStep: 9e4, TargetReward: 200, UsedTime: 2500s\"\n",
    "# args.env = PreprocessEnv(env=gym.make('BipedalWalker-v3'))\n",
    "# args.reward_scale = 2 ** 0  # RewardRange: -200 < -150 < 300 < 334\n",
    "# args.break_step = int(2e5)\n",
    "# args.if_allow_break = False\n",
    "# \"TotalStep: 2e5, TargetReward: 300, UsedTime: 5000s\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| GPU id: 0, cwd: ./Pendulum-v0_0\n",
      "| Remove history\n",
      "ID      Step      MaxR |    avgR      stdR       objA      objC\n",
      "0   0.00e+00  -1520.74 |\n",
      "0   4.20e+03  -1432.98 |\n",
      "0   8.40e+03  -1227.07 |\n",
      "0   1.68e+04  -1185.30 |\n",
      "0   2.10e+04  -1124.60 |\n",
      "0   2.94e+04   -940.84 |\n",
      "0   4.20e+04   -877.70 |\n",
      "0   5.46e+04   -862.52 |\n",
      "0   7.14e+04   -746.69 |\n",
      "0   9.24e+04   -620.06 |\n",
      "0   9.66e+04   -564.93 |\n",
      "0   1.09e+05   -559.21 |\n",
      "0   1.13e+05   -396.94 |\n",
      "0   1.22e+05   -250.43 |\n",
      "0   1.39e+05   -116.70 |\n",
      "ID      Step   TargetR |    avgR      stdR   UsedTime  ########\n",
      "0   1.43e+05   -200.00 | -116.70    116.00        225  ########\n"
     ]
    }
   ],
   "source": [
    "'''train and evaluate'''\n",
    "train_and_evaluate(args)\n",
    "# train_and_evaluate__multiprocessing(args)  # try multiprocessing in formal version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo 3: Custom Env from AI4Finance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'TotalStep: 10e5, TargetReward: 1.62, UsedTime: 1000s'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args = Arguments(if_on_policy=True)\n",
    "'''choose an DRL algorithm'''\n",
    "from elegantrl.tutorial.agent import AgentPPO\n",
    "args.agent = AgentPPO()\n",
    "\n",
    "from elegantrl.tutorial.env import FinanceMultiStockEnv  # a standard env for ElegantRL, not need PreprocessEnv()\n",
    "args.env = FinanceMultiStockEnv(if_train=True)\n",
    "args.env_eval = FinanceMultiStockEnv(if_train=False)  # eva_len = 1699 - train_len\n",
    "args.reward_scale = 2 ** 0  # RewardRange: 0 < 1.0 < 1.25 <\n",
    "args.break_step = int(5e6)\n",
    "args.max_step = args.env.max_step\n",
    "args.max_memo = (args.max_step - 1) * 8\n",
    "args.batch_size = 2 ** 11\n",
    "\"TotalStep:  2e5, TargetReward: 1.25, UsedTime:  200s\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| GPU id: 0, cwd: ./FinanceStock-v2_0\n",
      "| Remove history\n",
      "ID      Step      MaxR |    avgR      stdR       objA      objC\n",
      "0   0.00e+00      1.06 |\n",
      "0   5.12e+03      1.10 |\n",
      "0   1.02e+04      1.22 |\n",
      "0   2.56e+04      1.29 |\n",
      "ID      Step   TargetR |    avgR      stdR   UsedTime  ########\n",
      "0   3.07e+04      1.25 |    1.29      0.03         29  ########\n"
     ]
    }
   ],
   "source": [
    "'''train and evaluate'''\n",
    "train_and_evaluate(args)\n",
    "# args.rollout_num = 8\n",
    "# train_and_evaluate__multiprocessing(args)  # try multiprocessing in formal version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo 4: train in PyBullet (MuJoCo) (wait for adding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym  # don't worry about 'WARN: Box bound precision lowered by casting to float32'\n",
    "import pybullet_envs  # PyBullet is free, but MuJoCo is paid\n",
    "from AgentEnv import decorate_env\n",
    "from AgentRun import Arguments, train_and_evaluate\n",
    "from AgentZoo import AgentTD3, AgentSAC, AgentPPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "env_name = 'AntBulletEnv-v0'\n",
    "assert env_name in {\n",
    "    \"AntBulletEnv-v0\", \n",
    "    \"Walker2DBulletEnv-v0\", \n",
    "    \"HalfCheetahBulletEnv-v0\",\n",
    "    \"HumanoidBulletEnv-v0\", \n",
    "    \"HumanoidFlagrunBulletEnv-v0\", \n",
    "    \"HumanoidFlagrunHarderBulletEnv-v0\",\n",
    "}\n",
    "env = gym.make(env_name)\n",
    "env = decorate_env(env, if_print=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Arguments()\n",
    "args.agent_rl = AgentSAC  # AgentSAC can't reach target_reward=2500, try AgentModSAC\n",
    "args.env = env\n",
    "args.reward_scale = 2 ** -3\n",
    "args.break_step = int(1e6 * 8)\n",
    "args.eval_times = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demo 5: Atari game (wait for adding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = 'breakout-v0'  # 'SpaceInvaders-v0'"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
